import os
import random
import shutil

def sample_labels_and_images(txt_folder, image_folder, output_folder, target_labels, sample_ratio,
                             image_ext='.jpg', seed=None):
    """Randomly sample label files and copy them with their images to ``output_folder``.

    Every regular file ending in ``.txt`` in ``txt_folder`` is a candidate. The
    matching image is expected in ``image_folder`` under the same stem with
    ``image_ext``. Both files of each sampled pair are copied flat into
    ``output_folder``, which is created if needed.

    Args:
        txt_folder: Folder containing the ``.txt`` label files.
        image_folder: Folder containing the matching image files.
        output_folder: Destination folder for the sampled label/image pairs.
        target_labels: Label classes of interest. NOTE: currently not used by
            this function; kept for interface compatibility.
        sample_ratio: Fraction of label files to sample, in ``[0, 1]``. The
            sample size is ``int(len(files) * sample_ratio)`` (truncated).
        image_ext: Extension of the image files (default ``'.jpg'``).
        seed: Optional seed for a private ``random.Random``. When ``None``,
            the module-level ``random`` state is used, so a global
            ``random.seed(...)`` call still takes effect.

    Returns:
        The list of sampled label file names, in the order they were copied.

    Raises:
        ValueError: If ``sample_ratio`` is outside ``[0, 1]``.
        FileNotFoundError: If a sampled label has no matching image.
    """
    if not 0 <= sample_ratio <= 1:
        raise ValueError(f'sample_ratio must be in [0, 1], got {sample_ratio!r}')

    # Create the output folder.
    os.makedirs(output_folder, exist_ok=True)

    # Collect the label files. Directories are skipped because shutil.copy would fail on them.
    # The names are sorted because os.listdir order is arbitrary; sorting makes a
    # seeded sample reproducible across runs and machines.
    txt_files = sorted(
        name for name in os.listdir(txt_folder)
        if name.endswith('.txt') and os.path.isfile(os.path.join(txt_folder, name))
    )

    # Number of files to draw (truncated toward zero).
    num_samples = int(len(txt_files) * sample_ratio)

    # Draw the sample without replacement.
    sampler = random if seed is None else random.Random(seed)
    sampled_files = sampler.sample(txt_files, num_samples)

    for sampled_file in sampled_files:
        # The matching image has the same stem as the label file.
        image_file = os.path.splitext(sampled_file)[0] + image_ext

        sampled_txt_path = os.path.join(txt_folder, sampled_file)
        sampled_image_path = os.path.join(image_folder, image_file)

        output_txt_path = os.path.join(output_folder, sampled_file)
        output_image_path = os.path.join(output_folder, image_file)

        # Copy the image first. If it is missing, the FileNotFoundError fires
        # before the label is written, so no orphan label is left in the output.
        shutil.copy(sampled_image_path, output_image_path)
        shutil.copy(sampled_txt_path, output_txt_path)

    # Report how many pairs were sampled.
    print(f'Sampled {num_samples} labels and images.')
    return sampled_files

if __name__ == "__main__":

    txt_folder = 'D:/datasets/visDrone/VisDrone2019-DET-val/labels'  # folder of .txt label files
    image_folder = 'D:/datasets/visDrone/VisDrone2019-DET-val/images'  # folder of matching images
    # NOTE(review): this folder name says 0.01, but sample_ratio below is 0.02 (2%).
    # Confirm which value is intended so the output folder is not mislabeled.
    output_folder = 'D:/datasets/visDrone/VisDrone2019-DET-val/new_label_0.01'  # output folder
    target_labels = ['0', '1', '3', '4', '5', '8']  # label classes selected for droneYOLO (not used by the sampler)
    sample_ratio = 0.02  # sample 2% of the label files

    # Run the sampling.
    sample_labels_and_images(txt_folder, image_folder, output_folder, target_labels, sample_ratio)